from utils_general import *
from utils_methods import *

def _build_arg_parser():
    """Return the ArgumentParser that defines this script's command-line flags.

    Flag spellings mix '-' and '_' separators; argparse stores every flag under
    an attribute name with '-' replaced by '_' (e.g. ``--non-iid`` becomes
    ``args.non_iid``, ``--local-epochs`` becomes ``args.local_epochs``). The
    spellings are part of the CLI and must not be changed.

    NOTE(review): ``--alpha`` and ``--global-learning-rate`` are not read
    anywhere in the visible part of this script -- confirm whether they are
    still needed.
    """
    # (flag, add_argument keyword options), in the order the flags are
    # registered; the order determines --help output and Namespace field order.
    option_specs = [
        ('--dataset', {'default': 'CIFAR10', 'type': str}),
        ('--model', {'default': 'Resnet18', 'type': str}),
        ('--non-iid', {'default': False, 'action': 'store_true'}),
        ('--rule-arg', {'default': 0.6, 'type': float}),
        ('--act_prob', {'default': 0.1, 'type': float}),
        ('--method', {'default': 'FedAvg', 'type': str}),
        ('--n_client', {'default': 100, 'type': int}),
        ('--rounds', {'default': 500, 'type': int}),
        ('--local-epochs', {'default': 5, 'type': int}),
        ('--alpha', {'default': 0.1, 'type': float}),
        ('--alpha-coef', {'default': 0.1, 'type': float}),
        ('--weight_decay', {'default': 0.001, 'type': float}),
        ('--local-learning-rate', {'default': 0.1, 'type': float}),
        ('--global-learning-rate', {'default': 1.0, 'type': float}),
        ('--lr-decay', {'default': 0.9995, 'type': float}),
        ('--sch-gamma', {'default': 1.0, 'type': float}),
        ('--test-per', {'default': 1, 'type': int}),
        ('--batchsize', {'default': 50, 'type': int}),
        ('--save_period', {'default': 10000, 'type': int}),
        ('--seed', {'default': 20, 'type': int}),
        ('--rho', {'default': 0.1, 'type': float}),
    ]
    arg_parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        arg_parser.add_argument(flag, **options)
    return arg_parser


# Parse the command line once at import time and echo the resolved settings.
parser = _build_arg_parser()
args = parser.parse_args()
print(args)

# Root directory handed to DatasetObject as its data_path.
data_path = './'

# Short module-level aliases for the parsed flags shared by the training calls.
n_client = args.n_client                  # --n_client
com_amount = args.rounds                  # --rounds
model_name = args.model                   # --model
act_prob = args.act_prob                  # --act_prob
batch_size = args.batchsize               # --batchsize
weight_decay = args.weight_decay          # --weight_decay
lr_decay_per_round = args.lr_decay        # --lr-decay
save_period = args.save_period            # --save_period

# Split the dataset across n_client clients. Without --non-iid the split uses
# rule='iid'; with --non-iid it uses rule='Drichlet' and forwards --rule-arg as
# rule_arg (presumably the Dirichlet concentration -- confirm in DatasetObject).
# NOTE(review): 'Drichlet' is kept exactly as spelled; it is presumably the key
# DatasetObject matches on, so do not correct the spelling here on its own.
_dataset_kwargs = dict(dataset=args.dataset, n_client=n_client, seed=args.seed,
                       unbalanced_sgm=0, data_path=data_path)
if not args.non_iid:
    data_obj = DatasetObject(rule='iid', **_dataset_kwargs)
else:
    data_obj = DatasetObject(rule='Drichlet', rule_arg=args.rule_arg,
                             **_dataset_kwargs)

def model_func():
    """Return ``client_model(model_name)`` for the architecture chosen by --model.

    Called once below to build ``init_model`` and passed as ``model_func`` to
    the train_* functions.
    """
    return client_model(model_name)


# Seed torch's RNG with a fixed value before building the initial model so its
# random initialization is repeatable. This seed is independent of --seed,
# which in this script is only forwarded to DatasetObject.
torch.manual_seed(0)
init_model = model_func()

if __name__=='__main__':

    if args.method == 'FedAvg':
        epoch = args.local_epochs
        learning_rate = args.local_learning_rate
        test_per = args.test_per
        train_FedAvg(data_obj=data_obj,act_prob=act_prob, learning_rate=learning_rate,
                    batch_size=batch_size, epoch=epoch, com_amount=com_amount, test_per=test_per,
                    weight_decay=weight_decay, model_func=model_func, init_model=init_model,
                    sch_step=1, sch_gamma=1, rand_seed=0, lr_decay_per_round=lr_decay_per_round)

    elif args.method == 'SCAFFOLD':
        epoch = args.local_epochs
        n_data_per_client = np.concatenate(data_obj.client_x, axis=0).shape[0] / n_client
        n_iter_per_epoch = np.ceil(n_data_per_client / batch_size)
        n_minibatch = (epoch * n_iter_per_epoch).astype(np.int64)
        learning_rate = args.local_learning_rate
        test_per = args.test_per
        train_SCAFFOLD(data_obj=data_obj, act_prob=act_prob,
            learning_rate=learning_rate, batch_size=batch_size, n_minibatch=n_minibatch,
            com_amount=com_amount, test_per=test_per, weight_decay=weight_decay,
            model_func=model_func, init_model=init_model,
            sch_step=1, sch_gamma=1, rand_seed=0, lr_decay_per_round=lr_decay_per_round)

    
    elif args.method == 'FedSpeed':
        epoch = args.local_epochs
        alpha_coef = args.alpha_coef
        learning_rate = args.local_learning_rate
        test_per = args.test_per
        train_FedSpeed(data_obj=data_obj, act_prob=act_prob, learning_rate=learning_rate,
                    batch_size=batch_size,  epoch=epoch, com_amount=com_amount, test_per=test_per,
                    weight_decay=weight_decay, model_func=model_func, init_model=init_model,
                    alpha_coef=alpha_coef, sch_step=1, sch_gamma=args.sch_gamma, rho=args.rho,
                    rand_seed=0, lr_decay_per_round=lr_decay_per_round)